"""
## LaVCa step 2: Identify the optimal image set by finding the top-N images 
## that most strongly activate each voxel (according to the trained encoding models).

[Image]
python -m LaVCa.step2_search_optimal_images \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default None \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_path_for_view ./data/OpenImages/frames_518x518px \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --device cpu

"""

import os
import numpy as np
from tqdm import tqdm
import shutil
import argparse
from utils.utils import search_best_layer, make_filename, load_frames, create_volume_index_and_weight_map
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
import cortex
from utils.nsd_access import NSDAccess


def save_npy(save_path, data):
    """Write ``data`` to ``save_path`` with ``np.save`` unless the path already exists.

    Existing files are never overwritten, so repeated runs can reuse earlier output.
    """
    if not os.path.exists(save_path):
        np.save(save_path, data)
    
def copy_file_to_directory(src_path, dest_dir, new_name):
    """Copy ``src_path`` into ``dest_dir`` under the file name ``new_name``.

    ``dest_dir`` is created in every case; a source path that does not exist
    is silently skipped.
    """
    os.makedirs(dest_dir, exist_ok=True)
    if not os.path.exists(src_path):
        return
    shutil.copy(src_path, os.path.join(dest_dir, new_name))


def load_and_prepare_dir(dir_name, frame_paths, stim_root_path, dataset_name, subject_name, modality, modality_hparam, model_name, best_layer, reduce_dims):
    """Load the saved feature arrays for every frame of one stimulus directory.

    Features are read from
    ``<stim_root_path>/<dataset_name>/<modality>/<modality_hparam>/<model_name>/<best_layer>/<dir_name>/``.
    When ``reduce_dims[0]`` is not ``"default"``, a per-subject reduced file
    (``<name>_<subject_name>_ave_<make_filename(reduce_dims)>.npy``) is preferred
    and the raw ``<name>.npy`` is used as a fallback.

    Args:
        dir_name: Name of the directory the frames belong to.
        frame_paths: Paths of the frames; only the base name without its
            media extension is used to locate the feature file.
        stim_root_path: Root directory of the stimulus features.
        dataset_name, modality, modality_hparam, model_name, best_layer:
            Path components selecting the feature set.
        subject_name: Subject identifier used in the reduced-feature file name.
        reduce_dims: Dimensionality-reduction setting; element 0 equal to
            ``"default"`` means raw features are loaded.

    Returns:
        ``(dir_stims, mov_paths)`` where ``dir_stims`` stacks one entry per frame
        along axis 0 and ``mov_paths`` holds the matching ``"<dir_name>/<name>"``
        strings, or ``(None, None)`` when ``frame_paths`` is empty.

    Raises:
        FileNotFoundError: If a required feature file is missing (from ``np.load``).
    """
    # The feature directory of the selected layer does not depend on the frame.
    stim_best_layer_path = os.path.join(
        stim_root_path, dataset_name, modality, modality_hparam,
        model_name, best_layer
    )
    if reduce_dims[0] == "default":
        reduced_suffix = None
    else:
        reduced_suffix = f"_{subject_name}_ave_{make_filename(reduce_dims)}"

    media_extensions = (".txt", ".mp4", ".png", ".jpg", ".wav", ".mp3")
    dir_stims = []
    mov_paths = []
    for frame_path in frame_paths:
        movname = os.path.basename(frame_path)
        for extension in media_extensions:
            movname = movname.replace(extension, "")

        feature_path = f"{stim_best_layer_path}/{dir_name}/{movname}.npy"
        if reduced_suffix is not None:
            reduced_path = f"{stim_best_layer_path}/{dir_name}/{movname}{reduced_suffix}.npy"
            if os.path.exists(reduced_path):
                feature_path = reduced_path

        dir_stims.append(np.load(feature_path))
        mov_paths.append(f"{dir_name}/{movname}")

    if not dir_stims:
        return None, None

    dir_stims = np.array(dir_stims)
    if dir_stims.ndim == 3:
        # Drop singleton feature axes only. A bare squeeze() would also drop the
        # sample axis for a single-frame directory, turning (1, 1, D) into (D,).
        singleton_axes = tuple(axis for axis in (1, 2) if dir_stims.shape[axis] == 1)
        if singleton_axes:
            dir_stims = dir_stims.squeeze(axis=singleton_axes)

    return dir_stims, mov_paths


def search_optimal_imgs(volume_index, weight_index, subject_name, dataset_name, args, modality, modality_hparam, model_name, expanded_stims,
                  layer_weight, all_movnames):
    """Predict one voxel's response to every stimulus and collect its top-100 images.

    The prediction is the product of ``expanded_stims`` with the voxel's column
    of ``layer_weight``. Two outputs are written under
    ``./data/nsd/insilico/<subject>/<dataset>_<max_samples>/.../whole/voxel<index>``:

    * ``resp_dict.npy`` -- dict mapping each stimulus name to its predicted response.
    * ``stim_top100/`` -- copies of the 100 highest-response images found under
      ``args.dataset_path_for_view`` (``.jpg`` preferred, ``.png`` as fallback).

    A ``temp_insilico.tmp`` lock file in the output directory lets several
    processes share the voxel list: a voxel whose lock exists is skipped.
    The lock is created atomically and removed only by the call that created it.
    NOTE: a lock left behind by a killed process must be deleted manually.

    Args:
        volume_index: Voxel identifier used for logging and the output directory name.
        weight_index: Column of ``layer_weight`` belonging to this voxel.
        subject_name, dataset_name, modality, modality_hparam, model_name:
            Components of the output path.
        args: Parsed CLI arguments; ``max_samples``, ``reduce_dims`` and
            ``dataset_path_for_view`` are read.
        expanded_stims: Stimulus features, one row per stimulus.
        layer_weight: Encoding weights indexed as ``layer_weight[:, weight_index]``.
        all_movnames: Stimulus names aligned with the rows of ``expanded_stims``.

    Returns:
        None.
    """
    print(f"Volume_index: {volume_index}")
    volume_index_pad = str(volume_index).zfill(6)
    resp_save_path = f"./data/nsd/insilico/{subject_name}/{dataset_name}_{args.max_samples}/{modality}/{modality_hparam}/{model_name}_{make_filename(args.reduce_dims[0:2])}/whole/voxel{volume_index_pad}"
    os.makedirs(resp_save_path, exist_ok=True)

    # For parallel processing: O_CREAT | O_EXCL makes "check and create" a single
    # atomic step, so two workers can never both claim the same voxel.
    temp_file_path = os.path.join(resp_save_path, "temp_insilico.tmp")
    try:
        lock_fd = os.open(temp_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        # The lock belongs to another worker; leave it in place.
        print(f"Simulation for {volume_index_pad} is being processed.")
        return
    os.close(lock_fd)

    try:
        print(f"Now processing: {volume_index_pad}")

        resp_dict_path = f"{resp_save_path}/resp_dict.npy"
        if os.path.exists(resp_dict_path):
            try:
                np.load(resp_dict_path, allow_pickle=True).item()
                print(f"Already processed: {resp_save_path}")
                return
            except Exception as err:
                # An unreadable (e.g. partially written) result is recomputed below.
                print(f"Could not read {resp_dict_path} ({err}); recomputing.")

        voxel_weight = layer_weight[:, weight_index]
        print(expanded_stims.shape, voxel_weight.shape)

        # as_tensor avoids re-copying the full stimulus matrix for every voxel
        # when it is already a float32 array.
        stims_tensor = torch.as_tensor(expanded_stims, dtype=torch.float32)
        weight_tensor = torch.as_tensor(voxel_weight, dtype=torch.float32)
        resp_val_pred_all = stims_tensor @ weight_tensor.squeeze()
        # .cpu() is a no-op for CPU tensors; it is needed before .numpy() if the
        # inputs live on a GPU (presumably the --device cuda path -- TODO confirm).
        # atleast_1d keeps a single-stimulus result iterable for zip() below.
        resp_val_pred_all = np.atleast_1d(resp_val_pred_all.cpu().numpy())
        print(f"resp_val_pred_all shape: {resp_val_pred_all.shape}")

        voxel_resp_dict = dict(zip(all_movnames, resp_val_pred_all))
        np.save(resp_dict_path, voxel_resp_dict)

        sorted_movies = sorted(voxel_resp_dict.items(), key=lambda x: x[1], reverse=True)
        top_100_movname = [key for key, _ in sorted_movies[:100]]

        mov_copy_dir = f"{resp_save_path}/stim_top100"
        os.makedirs(mov_copy_dir, exist_ok=True)

        for rank, topn_movname in enumerate(top_100_movname, start=1):
            topn_movname_save = topn_movname.replace("/", "_")
            # Prefer the .jpg version of the image and fall back to .png.
            for extension in (".jpg", ".png"):
                mov_path = f"{args.dataset_path_for_view}/{topn_movname}{extension}"
                if os.path.exists(mov_path):
                    new_mov_name = f"top{rank}_{topn_movname_save}{extension}"
                    copy_file_to_directory(mov_path, mov_copy_dir, new_mov_name)
                    break
    finally:
        # Release the lock this call created; cleanup is best effort.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass

def reduce_dimensions(dir_stims, reducer_projector):
    """Reduce the dimensionality of a whole chunk of features in one call.

    Returns whatever ``reducer_projector.transform(dir_stims)`` returns.
    """
    return reducer_projector.transform(dir_stims)

def main(args):
    """Run the optimal-image search for every requested subject.

    For each subject in ``args.subject_names`` this:

    1. selects the encoding-model layer (``search_best_layer`` when
       ``args.layer_selection == "best"``, otherwise the given layer name),
    2. loads that layer's weights (``coef_<filename>.npy``),
    3. selects voxels with ``create_volume_index_and_weight_map``,
    4. loads the stimulus features of ``args.dataset_name`` directory by
       directory, projecting them with the saved reducer (and caching the
       projected features on disk) when ``args.reduce_dims[0] != "default"``,
    5. calls ``search_optimal_imgs`` for every selected voxel.

    Args:
        args: Parsed command-line arguments (see the parser at the bottom of the file).

    Raises:
        RuntimeError: If no stimulus features could be loaded for the dataset.
    """
    score_root_path = "./data/nsd/encoding"
    modality = args.modality
    modality_hparam = args.modality_hparam
    model_name = args.model_name
    file_type = args.voxel_selection[0]
    threshold = float(args.voxel_selection[1])
    nsda = NSDAccess('./data/NSD')

    for subject_name in args.subject_names:
        print(subject_name)
        filename = make_filename(args.reduce_dims[0:2])

        print(f"Modality: {modality}, Modality hparams: {modality_hparam}, Feature: {model_name}, Filename: {filename}")
        # loading the selected layer per subject
        model_score_dir = f"{score_root_path}/{subject_name}/scores/{modality}/{modality_hparam}/{model_name}"
        if args.layer_selection == "best":
            target_best_cv_layer, _, _ = search_best_layer(model_score_dir, filename, select_topN="all")
        else:
            target_best_cv_layer = args.layer_selection
        print(f"Best layer: {target_best_cv_layer}")

        # Get encoding weight
        layer_path = f"{model_score_dir}/{target_best_cv_layer}"

        if args.device == "cuda":
            import cupy as cp
            layer_weight = cp.load(f"{layer_path}/coef_{filename}.npy")
        else:
            layer_weight = np.load(f"{layer_path}/coef_{filename}.npy")
        print(f"Shape of the layer's weight: {layer_weight.shape}")

        volume_index, weight_index_map, target_top_voxels = create_volume_index_and_weight_map(
            subject_name=subject_name,
            file_type=file_type,
            threshold=threshold,
            model_score_dir=model_score_dir,
            target_best_cv_layer=target_best_cv_layer,
            filename=filename,
            nsda=nsda,
            atlasnames=args.atlasname  # args.atlasname is expected to be a list
        )

        stim_root_path = "./data/stim_features/nsd"
        if args.reduce_dims[0] != "default":
            projector_stem = f"{stim_root_path}/{modality}/{modality_hparam}/{model_name}/{target_best_cv_layer}/projector_{subject_name}_ave_{filename}"
            try:
                reducer_proj_path = f"{projector_stem}.npy"
                reducer_projector = np.load(reducer_proj_path, allow_pickle=True).item()
            except Exception:
                # Fall back to the pickled projector when the .npy one cannot be
                # loaded; a failure here propagates to the caller.
                reducer_proj_path = f"{projector_stem}.pkl"
                reducer_projector = np.load(reducer_proj_path, allow_pickle=True)
        else:
            reducer_projector = None
        print(reducer_projector)

        dataset_name = args.dataset_name

        if args.max_samples == "full":
            break_point = 100000000
        else:
            break_point = int(args.max_samples)

        frames_all = load_frames(f"{args.dataset_path}", dataset_name)
        print(f"Number of directory: {len(frames_all)}")

        # Load stimulus
        count = 0
        all_stims = []
        all_mov_paths = []
        batch_size = 20  # number of directories loaded per batch

        insilico_stim_root_path = "./data/stim_features/"
        stim_best_layer_path = os.path.join(
            insilico_stim_root_path, dataset_name, modality, modality_hparam,
            model_name, target_best_cv_layer
        )
        # Process the files of every directory, one batch of directories at a time
        dir_names = list(frames_all.keys())
        for i in range(0, len(dir_names), batch_size):
            batch_dir_names = dir_names[i:i + batch_size]

            with ThreadPoolExecutor() as load_executor:
                # NOTE(review): the loader builds its cache file name from the full
                # args.reduce_dims, while this function saves with reduce_dims[0:2];
                # confirm both yield the same name when more than two values are given.
                future_to_dirname = {load_executor.submit(load_and_prepare_dir, dir_name, frames_all[dir_name], insilico_stim_root_path, dataset_name, subject_name, modality, modality_hparam, model_name, target_best_cv_layer, args.reduce_dims): dir_name for dir_name in batch_dir_names}

                loaded_data = []
                for future in tqdm(as_completed(future_to_dirname), total=len(future_to_dirname)):
                    dir_stims, mov_paths = future.result()
                    if dir_stims is not None:
                        # A single-frame directory may come back 1-D after squeezing;
                        # restore the sample axis so shape[1] and concatenation work.
                        loaded_data.append((np.atleast_2d(dir_stims), mov_paths))
                        count += len(mov_paths)
                        if count >= break_point:
                            break

            # Main processing code
            if args.reduce_dims[0] != "default":

                for dir_stims, mov_paths in tqdm(loaded_data):
                    if dir_stims.shape[1] != int(args.reduce_dims[1]):
                        dir_stims_transformed = reduce_dimensions(dir_stims, reducer_projector)

                        # Use ThreadPoolExecutor for parallel saving
                        with ThreadPoolExecutor() as executor:
                            futures = []
                            for idx, movname in enumerate(mov_paths):
                                save_path = f"{stim_best_layer_path}/{movname}_{subject_name}_ave_{filename}.npy"
                                futures.append(executor.submit(save_npy, save_path, dir_stims_transformed[idx]))

                            # Optional: ensure all futures are completed (for error handling)
                            for future in tqdm(futures, desc="Saving files"):
                                future.result()
                    else:
                        # Features already have the target dimensionality.
                        dir_stims_transformed = dir_stims

                    all_stims.append(dir_stims_transformed)
                    all_mov_paths.extend(mov_paths)
            else:
                for dir_stims, mov_paths in loaded_data:
                    all_stims.append(dir_stims)
                    all_mov_paths.extend(mov_paths)

            if count >= break_point:
                break

        if not all_stims:
            # Fail with a clear message instead of an opaque shape error later on.
            raise RuntimeError(
                f"No stimulus features were loaded for dataset '{dataset_name}' "
                f"(frames listed under {args.dataset_path})."
            )

        all_stims = np.concatenate(all_stims, axis=0)
        # Drop singleton axes but always keep axis 0 (one row per stimulus), so a
        # single loaded stimulus still yields a 2-D matrix.
        singleton_axes = tuple(axis for axis in range(1, all_stims.ndim) if all_stims.shape[axis] == 1)
        if singleton_axes:
            all_stims = all_stims.squeeze(axis=singleton_axes)
        print(f"Reduced features shape: {all_stims.shape}")
        if args.device == "cuda":
            all_stims = cp.asarray(all_stims)

        # Search optimal images
        for vol_idx in volume_index:
            weight_index = weight_index_map[vol_idx]
            search_optimal_imgs(vol_idx, weight_index, subject_name, dataset_name, args, modality, modality_hparam, model_name,
                          all_stims, layer_weight, all_mov_paths)

if __name__ == "__main__":
    # Command-line interface: each entry is (flag, add_argument keyword options),
    # registered in the order listed.
    cli_options = [
        ("--subject_names", dict(nargs="*", type=str, required=True)),
        ("--atlasname", dict(type=str, nargs="*", required=True)),
        ("--modality", dict(type=str, required=True, help="Name of the modality to use.")),
        ("--modality_hparam", dict(type=str, required=True)),
        ("--model_name", dict(type=str, required=True)),
        ("--reduce_dims", dict(nargs="*", type=str, required=True)),
        ("--dataset_name", dict(type=str, required=True)),
        ("--max_samples", dict(type=str, required=True)),
        ("--dataset_path", dict(type=str, required=True)),
        (
            "--dataset_path_for_view",
            dict(
                type=str,
                required=True,
                help="Path to the dataset for viewing the top 100 movies.",
            ),
        ),
        (
            "--voxel_selection",
            dict(
                nargs="*",
                type=str,
                required=True,
                help="Selection method of voxels. Implemented type are 'uv' and 'share'.",
            ),
        ),
        ("--layer_selection", dict(type=str, required=False, default="best")),
        (
            "--device",
            dict(type=str, required=True, choices=["cuda", "cpu"], help="Device to use."),
        ),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    main(parser.parse_args())
